import os
import sys
import time
import torch
import torch.nn.functional as F
import mediapy
import imageio
import trimesh
import warp as wp
import numpy as np
from PIL import Image
from pathlib import Path
from natsort import natsorted
from typing import Optional, List, Tuple
from typing_extensions import Literal
from modules.nclaw.sph import volume_sampling
from modules.d3gs.scene.gaussian_model import GaussianModel
from modules.d3gs.gaussian_renderer import get_rasterizer
from modules.d3gs.utils.binding_utils import (
    gaussian_binding,
    gaussian_binding_with_clip_v1
)
from modules.d3gs.utils.simulation_utils import (
    torch2warp_mat33,
    deform_cov_by_F
)

from modules.d3gs.scene.cameras import MiniCam
from modules.d3gs.utils.graphics_utils import getWorld2View2, getProjectionMatrix, focal2fov, fov2focal


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    Attributes:
        val: The value passed to the latest ``update`` call.
        sum: Weighted total of all recorded values (``val * n`` accumulated).
        count: Accumulated weight (sum of all ``n``).
        avg: ``sum / count`` as of the latest ``update`` call.
    """

    def __init__(self):
        # A fresh meter is simply a meter that has just been reset.
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Raises:
            ZeroDivisionError: If the accumulated ``count`` is still zero
                after adding ``n``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Logger(object):
    """Tee-style writer that mirrors every message to stdout and a log file.

    The instance exposes ``write``/``flush`` so it can stand in for a text
    stream (for example by assigning it to ``sys.stdout``).

    Args:
        filename: Path of the log file; it is opened in append mode.
    """

    def __init__(self, filename):
        # Capture the stream that is current at construction time so that a
        # later reassignment of ``sys.stdout`` does not create a write loop.
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the terminal and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log file is up to date even if the process dies.
        self.log.flush()

    def flush(self):
        """Flush both underlying streams.

        Forwarding the flush makes ``print(..., flush=True)`` behave as
        expected when this object replaces ``sys.stdout``.
        """
        self.terminal.flush()
        # The log may already be closed (e.g. during interpreter shutdown);
        # flushing a closed file would raise ValueError.
        if not self.log.closed:
            self.log.flush()


class Timer(object):
    """Wall-clock stopwatch that reports elapsed time as a short string."""

    def __init__(self):
        # Reference timestamp taken when the timer is created.
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time since construction, divided by ``p``.

        The quotient is truncated to whole seconds and formatted as
        ``'<n>s'`` below one minute, ``'<n>m'`` (rounded minutes) below one
        hour, and ``'<x.y>h'`` otherwise.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return f'{seconds}s'
        if seconds < 3600:
            return f'{round(seconds / 60)}m'
        return f'{seconds / 3600:.1f}h'


def verbose_points(points: torch.Tensor, tag: str=''):
    """Print the per-axis ``[min, max]`` extent of a point set.

    Args:
        points: 2-D tensor whose first three columns are read as x, y, z.
        tag: Label printed on the line before the extents.
    """
    # Gather every bound before printing so that a malformed input raises
    # without leaving partial output behind.
    extents = [(points[:, axis].min(), points[:, axis].max()) for axis in range(3)]
    (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = extents

    print(f"{tag}\n  x: [{x_lo}, {x_hi}]")
    print(f"  y: [{y_lo}, {y_hi}]")
    print(f"  z: [{z_lo}, {z_hi}]\n")


def save_video_mediapy(
    frame_dir: Path,
    frame_name: str,
    output_path: Path,
    skip_frame: int = 1,
    fps: int = 30,
    white_bg: bool = False,
):
    """Assemble image frames from a directory into a video with mediapy.

    Args:
        frame_dir: Directory that is searched for frames.
        frame_name: Glob pattern (relative to ``frame_dir``) selecting frames.
        output_path: Destination passed to ``mediapy.write_video``.
        skip_frame: Keep every ``skip_frame``-th frame of the naturally
            sorted file list.
        fps: Frame rate passed to the writer.
        white_bg: Composite RGBA frames over white when true, black otherwise.

    Raises:
        ValueError: If a frame is neither ``RGB`` nor ``RGBA``.
    """

    def _as_rgb_array(image):
        # RGB frames are used as-is; RGBA frames are alpha-composited over a
        # solid background in normalized [0, 1] space.
        if image.mode == "RGB":
            return np.array(image)
        if image.mode != "RGBA":
            raise ValueError(f"Unsupported image mode: {image.mode}")
        backdrop = np.array([1, 1, 1]) if white_bg else np.array([0, 0, 0])
        rgba = np.array(image) / 255.0
        color, alpha = rgba[:, :, :3], rgba[:, :, 3:]
        blended = color * alpha + (1 - alpha) * backdrop
        return np.array(blended * 255.0, dtype=np.uint8)

    selected_paths = natsorted(list(frame_dir.glob(frame_name)))[::skip_frame]
    np_frames = [_as_rgb_array(Image.open(path)) for path in selected_paths]

    mediapy.write_video(output_path, np_frames, fps=fps, qp=18)
    print(f"Video saved to {output_path} with skip frame {skip_frame} and fps {fps}")


# In case mediapy does not work correctly, use imageio to save GIFs instead.
def save_gif_imageio(
    frame_dir: Path,
    frame_name: str,
    output_path: Path,
    skip_frame: int = 1,
    fps: int = 30,
    white_bg: bool = False,
    resize: Optional[Tuple[int, int]] = None,
):
    """Assemble image frames from a directory into a GIF with imageio.

    Args:
        frame_dir: Directory that is searched for frames.
        frame_name: Glob pattern (relative to ``frame_dir``) selecting frames.
        output_path: Destination passed to ``imageio.get_writer``.
        skip_frame: Keep every ``skip_frame``-th frame of the naturally
            sorted file list.
        fps: Frame rate passed to the writer.
        white_bg: Composite RGBA frames over white when true, black otherwise.
        resize: Optional size handed to ``Image.resize`` for every frame;
            frames keep their original size when ``None``.

    Raises:
        ValueError: If a frame is neither ``RGB`` nor ``RGBA``.
    """
    chosen_paths = natsorted(list(frame_dir.glob(frame_name)))[::skip_frame]

    np_frames = []
    for path in chosen_paths:
        frame = Image.open(path)
        if resize is not None:
            # Resize to the caller-provided size before compositing.
            frame = frame.resize(size=resize)

        if frame.mode == "RGB":
            np_frames.append(np.array(frame))
            continue
        if frame.mode != "RGBA":
            raise ValueError(f"Unsupported image mode: {frame.mode}")

        # Alpha-composite the frame over a solid background in [0, 1] space.
        backdrop = np.array([1, 1, 1]) if white_bg else np.array([0, 0, 0])
        rgba = np.array(frame) / 255.0
        color, alpha = rgba[:, :, :3], rgba[:, :, 3:]
        blended = color * alpha + (1 - alpha) * backdrop
        np_frames.append(np.array(blended * 255.0, dtype=np.uint8))

    with imageio.get_writer(output_path, mode='I', fps=fps, loop=0) as writer:
        for frame_arr in np_frames:
            writer.append_data(frame_arr)

    print(f"GIF saved to {output_path} with skip frame {skip_frame} and fps {fps}")


def uniform_sampling(mesh: trimesh.Trimesh, resolution: int) -> np.ndarray:
    """Sample the interior of a mesh on a regular grid.

    A ``resolution``-per-axis grid spanning the mesh's bounding box is
    generated and only the grid points reported as inside by
    ``mesh.contains`` are kept.

    Args:
        mesh: Mesh to sample. It is not modified; sampling runs on a copy.
        resolution: Number of grid samples along each axis.

    Returns:
        Array of interior grid points in the mesh's original coordinates.
    """
    bounds = mesh.bounds.copy()

    # Shift a private copy so its lower corner sits at the origin. The
    # previous implementation shifted the caller's mesh in place and never
    # moved it back.
    shifted = mesh.copy()
    shifted.vertices = shifted.vertices - bounds[0]

    upper_bound = shifted.vertices.max(0)
    dims = np.linspace(np.zeros(3), upper_bound, resolution).T
    grid = np.stack(np.meshgrid(*dims, indexing='ij'), axis=-1).reshape(-1, 3)
    p_x = grid[shifted.contains(grid)]

    # Undo the shift so the samples live in the original frame.
    return p_x + bounds[0]


def volumetric_sampling(mesh: trimesh.Trimesh, resolution: int, asset_path: Path) -> np.ndarray:
    """Sample particles inside a mesh using the project's ``volume_sampling``.

    A normalized copy of the mesh (centered at 0.5 and scaled by its largest
    bounding-box extent) is exported to a temporary OBJ file, sampled into a
    temporary VTK file, and the resulting points are mapped back to the
    original coordinates.

    Args:
        mesh: Mesh to sample. It is not modified; normalization runs on a copy.
        resolution: Grid resolution per axis; the sampling radius is
            ``0.5 / resolution``.
        asset_path: Directory in which ``temp.obj`` and ``temp.vtk`` are
            written and removed again.

    Returns:
        Array of sampled points in the mesh's original coordinates.
    """
    import pyvista

    bounds = mesh.bounds.copy()
    center = bounds.mean(0)
    extent = (bounds[1] - bounds[0]).max()

    # Normalize a private copy; the previous implementation overwrote the
    # caller's vertices and left the mesh normalized afterwards.
    normalized = mesh.copy()
    normalized.vertices = (normalized.vertices - center) / extent + 0.5

    cache_obj_path = asset_path / 'temp.obj'
    cache_vtk_path = asset_path / 'temp.vtk'
    try:
        normalized.export(cache_obj_path)

        radius = 1.0 / resolution * 0.5
        volume_sampling(cache_obj_path, cache_vtk_path, radius, res=(resolution, resolution, resolution))
        pcd: pyvista.PolyData = pyvista.get_reader(str(cache_vtk_path)).read()
        p_x = np.array(pcd.points).copy()
    finally:
        # Always remove the scratch files, even when export/sampling fails.
        cache_obj_path.unlink(missing_ok=True)
        cache_vtk_path.unlink(missing_ok=True)

    # Undo normalization.
    return (p_x - 0.5) * extent + center


def surface_sampling(mesh: trimesh.Trimesh, resolution: int) -> np.ndarray:
    """Sample points on a mesh surface plus a jittered duplicate of each.

    Args:
        mesh: Mesh whose surface is sampled.
        resolution: Target number of points; ``resolution // 2`` surface
            samples are requested and each is duplicated with noise.

    Returns:
        The surface samples followed by their noisy copies, stacked along
        the first axis.
    """
    surface_pts = trimesh.sample.sample_surface_even(mesh, resolution // 2)[0]

    # Second copy of the samples, perturbed by zero-mean Gaussian noise
    # with standard deviation 0.001.
    jitter = np.random.normal(0, 0.001, surface_pts.shape)
    jittered_pts = surface_pts + jitter

    return np.concatenate((surface_pts, jittered_pts), axis=0)


def get_warp_device(device: torch.device) -> wp.context.Device:
    """Return the Warp device matching a torch device.

    Args:
        device: Torch device. Any non-CUDA type maps to Warp's ``'cpu'``.

    Returns:
        The Warp device obtained from ``wp.get_device``.
    """
    if device.type != 'cuda':
        return wp.get_device('cpu')

    # ``torch.device('cuda')`` carries ``index=None``; formatting it directly
    # would produce the alias 'cuda:None'. Resolve it to torch's current
    # CUDA device instead.
    index = device.index
    if index is None:
        index = torch.cuda.current_device()
    return wp.get_device(f'cuda:{index}')


@torch.no_grad()
def prepare_simulation_data(
    save_dir: Path,
    kernels_path: Path,
    particles_path: Optional[Path] = None,
    mesh_path: Optional[Path] = None,
    mesh_sample_mode: Literal["uniform", "volumetric", "surface"] = "volumetric",
    mesh_sample_resolution: int = 30,
    sh_degree: int = 3,
    opacity_thres: float = 0.02,
    particles_downsample_factor: int = 3,
    confidence: float = 0.95,
    max_particles: int = 10,
):
    """Prepare kernels, particles and kernel-particle bindings for simulation.

    When ``kernels.ply``, ``particles.ply`` and ``bindings.pt`` all exist in
    ``save_dir`` nothing is recomputed. Otherwise the function:

    1. loads the Gaussian kernels, drops those whose opacity is not above
       ``opacity_thres`` and saves the rest to ``save_dir / "kernels.ply"``;
    2. obtains particles from ``particles_path`` or by sampling ``mesh_path``;
    3. randomly permutes the particles and keeps every
       ``particles_downsample_factor``-th one;
    4. adds the centers of kernels that received no particle in a first
       binding pass as extra particles;
    5. computes the final binding weights and writes ``particles.ply`` and
       ``bindings.pt`` to ``save_dir``.

    Args:
        save_dir: Output directory for the prepared files.
        kernels_path: PLY file with the Gaussian kernels.
        particles_path: Optional point-cloud file providing the particles;
            takes precedence over ``mesh_path``.
        mesh_path: Optional mesh file to sample particles from; it is also
            copied into ``save_dir`` as ``mesh<suffix>``.
        mesh_sample_mode: Sampling strategy used for ``mesh_path``.
        mesh_sample_resolution: Resolution forwarded to the sampler.
        sh_degree: Spherical-harmonics degree for the ``GaussianModel``.
        opacity_thres: Kernels must have opacity above this value to be kept.
        particles_downsample_factor: Stride used to thin the shuffled particles.
        confidence: Forwarded to both binding routines.
        max_particles: Forwarded to ``gaussian_binding_with_clip_v1``.

    Raises:
        ValueError: If neither ``particles_path`` nor ``mesh_path`` is given,
            or ``mesh_sample_mode`` is not supported.
        AssertionError: If the dense binding matrix has too many elements.
    """
    import shutil

    if (
        (save_dir / "kernels.ply").is_file()
        and (save_dir / "particles.ply").is_file()
        and (save_dir / "bindings.pt").is_file()
    ):
        print("===================================")
        print("Data already prepared. Skipping data preparation.\n")

    else:
        print("===================================")
        print('Start preparing data for simulation.\n')

        gaussians = GaussianModel(sh_degree)
        gaussians.load_ply(kernels_path.as_posix())

        opacity = gaussians.get_opacity
        retain_flag = opacity.squeeze() > opacity_thres

        print(f'Gaussians after pruning low opacity kernels: {retain_flag.sum()}')

        gaussians.load_ply_with_mask(kernels_path.as_posix(), retain_flag.cpu().numpy())
        gaussians.save_ply((save_dir / "kernels.ply").as_posix())

        if particles_path is not None:
            print(f'Extracting particles from pcd file [{particles_path}] ...')
            particles = trimesh.load(particles_path).vertices
        elif mesh_path is not None:
            print(f'Extracting particles from mesh file [{mesh_path}] ...')
            # Keep a copy of the source mesh next to the prepared data. A
            # library copy (instead of shelling out to ``cp``) is safe for
            # paths containing spaces or shell metacharacters.
            try:
                shutil.copy(mesh_path, save_dir / f"mesh{mesh_path.suffix}")
            except shutil.SameFileError:
                # The mesh already lives at the destination.
                pass
            mesh: trimesh.Trimesh = trimesh.load(mesh_path, force='mesh')
            if not mesh.is_watertight:
                print(f'[**WARNING**] Invalid mesh from [{mesh_path}]: not watertight!')
                print('[**WARNING**] Please manually check the sampled particles in case of unexpected results!')
            if mesh_sample_mode == "uniform":
                particles = uniform_sampling(mesh, mesh_sample_resolution)
            elif mesh_sample_mode == "volumetric":
                particles = volumetric_sampling(mesh, mesh_sample_resolution, save_dir)
            elif mesh_sample_mode == "surface":
                particles = surface_sampling(mesh, mesh_sample_resolution)
            else:
                raise ValueError(f"Unsupported mesh sample mode: {mesh_sample_mode}")
        else:
            raise ValueError("Either 'particles_path' or 'mesh_path' must be provided.")
        particles = torch.from_numpy(particles).float().cuda()

        # Downsample particles: shuffle, then keep every k-th one.
        rand_idx = torch.randperm(particles.shape[0])
        particles = particles[rand_idx][::particles_downsample_factor]
        particles = particles.contiguous()

        # Preliminary binding, used only to find kernels without particles.
        print('Pre-compute bindings to find gaussians without particle bindings ...')
        flag_mat_pre = gaussian_binding(gaussians, particles, confidence=confidence)

        num_particles_pre = flag_mat_pre.sum(1)
        has_particles_pre = num_particles_pre > 0

        # Kernels without any particle contribute their own center as a particle.
        to_clone_means3D = gaussians.get_xyz[~has_particles_pre].requires_grad_(False)
        print(f'Particles to be added: {to_clone_means3D.shape}')

        particles = torch.cat([particles, to_clone_means3D], dim=0)

        # Release the preliminary tensors before the final binding pass.
        del flag_mat_pre, num_particles_pre, has_particles_pre

        # Finalize binding.
        print('Finalize binding computation ...')
        weight_mat = gaussian_binding_with_clip_v1(
            gaussians, particles,
            confidence=confidence,
            max_particles=max_particles
        )
        # NOTE: The size of the flag mat should not exceed INT_MAX = 2_147_483_647
        assert weight_mat.reshape(-1).shape[0] < torch.iinfo(torch.int32).max
        weight_mat_sparse = weight_mat.to_sparse_coo()
        print("COO: ", weight_mat_sparse.indices().shape)

        # Number of particles with non-zero weight per kernel.
        num_particles = (weight_mat > 0).sum(1)

        # Save data. A dedicated local avoids rebinding the
        # ``particles_path`` argument.
        out_particles_path = save_dir / "particles.ply"
        point_np = particles.cpu().numpy()
        point_tr = trimesh.PointCloud(vertices=point_np)
        point_tr.export(out_particles_path)

        torch.save({
            # NOTE: different pytorch version may have different behavior
            # "bindings" : weight_mat_sparse.cpu(),
            # NOTE: use the following form to save the sparse tensor
            "bindings_ind": weight_mat_sparse.indices().cpu(),
            "bindings_val": weight_mat_sparse.values().cpu(),
            "bindings_size": weight_mat_sparse.size(),
            "n_particles": num_particles.cpu()
        }, save_dir / "bindings.pt")

    print('\nData preparation done.')
    print("===================================\n")
    
def init_xvcf(init_data, init_v, torch_device):
    """Build the initial per-particle state tensors.

    Args:
        init_data: Object exposing a NumPy array attribute ``pos`` with the
            particle positions.
        init_v: NumPy array with the initial velocity, either 1-D (one
            velocity broadcast to every particle) or 2-D (one row per
            particle).
        torch_device: Device on which the tensors are created.

    Returns:
        Tuple ``(p_x, init_v, C, F, S)``: positions, velocities, a zero
        ``(N, 3, 3)`` tensor, per-particle ``3x3`` identity matrices and
        another zero ``(N, 3, 3)`` tensor.

    Raises:
        ValueError: If ``init_v`` is neither 1-D nor 2-D.
    """
    p_x = torch.from_numpy(init_data.pos).float().to(torch_device)
    init_v = torch.from_numpy(init_v).float().to(torch_device)
    num_particles = p_x.shape[0]

    if init_v.ndim == 1:
        # Broadcast the single velocity vector to every particle.
        init_v = init_v.unsqueeze(0).expand(num_particles, -1)
    elif init_v.ndim != 2:
        # The previous check referenced ``self`` inside a free function and
        # raised NameError for every non-1-D input, including valid 2-D ones.
        raise ValueError(
            f"init_v must be 1-D or 2-D, got {init_v.ndim} dimensions"
        )

    C = torch.zeros(num_particles, 3, 3).to(p_x.device)
    # Named ``deform_grad`` locally so the module-level ``F`` alias
    # (torch.nn.functional) is not shadowed.
    deform_grad = torch.eye(3).unsqueeze(0).expand(num_particles, 3, 3).to(p_x.device)
    S = torch.zeros(num_particles, 3, 3).to(p_x.device)

    return p_x, init_v, C, deform_grad, S


def camerainfo_to_minicam(
    c2w: np.array,
    intrinsics: np.array,
    image_height: int,
    image_width: int,
    znear: Optional[float] = 0.01,
    zfar: Optional[float] = 100.0,
    trans: Optional[np.array] = [0.0, 0.0, 0.0],
    scale: Optional[float] = 1.0
):
    """Build a ``MiniCam`` from a camera-to-world pose and pinhole intrinsics.

    Args:
        c2w: Camera-to-world matrix, ``(3, 4)`` or ``(4, 4)``, in
            OpenGL/Blender axes. The input array is not modified.
        intrinsics: Matrix whose ``[0][0]`` and ``[1][1]`` entries are read
            as the focal lengths in x and y.
        image_height: Image height in pixels.
        image_width: Image width in pixels.
        znear: Near clipping distance.
        zfar: Far clipping distance.
        trans: Translation forwarded to ``getWorld2View2``.
        scale: Scale forwarded to ``getWorld2View2``.

    Returns:
        The constructed ``MiniCam``; its transforms are placed on CUDA.
    """
    # Work on a copy: the axis flip below is in place, and flipping the
    # caller's 4x4 array would undo itself on a second call with the same pose.
    c2w = np.array(c2w)
    if c2w.shape[0] == 3:  # (3, 4) -> (4, 4)
        c2w = np.concatenate([c2w, np.array([[0, 0, 0, 1]])], axis=0)

    # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
    c2w[:3, 1:3] *= -1

    # get the world-to-camera transform and set R, T
    w2c = np.linalg.inv(c2w)
    R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
    T = w2c[:3, 3]

    focalx = intrinsics[0][0]
    focaly = intrinsics[1][1]
    FovX = focal2fov(focalx, image_width)
    FovY = focal2fov(focaly, image_height)

    # float32 is requested explicitly, matching genesis_camera_to_minicam,
    # so the matrix product below has a single dtype.
    world_view_transform = torch.tensor(
        getWorld2View2(R, T, trans, scale), dtype=torch.float32
    ).transpose(0, 1).cuda()
    projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, fovX=FovX, fovY=FovY).transpose(0, 1).cuda()
    full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)

    minicam = MiniCam(
        width=image_width,
        height=image_height,
        fovy=FovY,
        fovx=FovX,
        znear=znear,
        zfar=zfar,
        world_view_transform=world_view_transform,
        full_proj_transform=full_proj_transform,
    )

    return minicam
    
def genesis_camera_to_minicam(
    pos: np.array,
    lookat: np.array,
    fov_degrees: float,
    image_height: int,
    image_width: int,
    up: Optional[np.array] = np.array([0.0, 0.0, -1.0]),
    znear: Optional[float] = 0.05,
    zfar: Optional[float] = 100.0,
    trans: Optional[np.array] = np.array([0.0, 0.0, 0.0]),
    scale: Optional[float] = 1.0
):
    """
    Build a MiniCam from a camera position, a look-at target and a vertical FoV.

    The camera frame is assembled with X = normalize(up x Z), Y = Z x X and
    Z pointing from the camera towards the target.

    Args:
        pos (np.array): Camera position in world coordinates.
        lookat (np.array): World-space point the camera looks at.
        fov_degrees (float): Vertical field of view in degrees.
        image_height (int): Image height in pixels.
        image_width (int): Image width in pixels.
        up (Optional[np.array]): Reference vector crossed with the viewing
            direction to obtain the camera X axis. Defaults to [0, 0, -1].
        znear (Optional[float]): Near clipping distance. Defaults to 0.05.
        zfar (Optional[float]): Far clipping distance. Defaults to 100.0.
        trans (Optional[np.array]): Translation forwarded to getWorld2View2.
            Defaults to [0, 0, 0].
        scale (Optional[float]): Scale forwarded to getWorld2View2.
            Defaults to 1.0.

    Returns:
        MiniCam: Camera whose transforms are placed on CUDA.
    """
    eye = np.array(pos)
    target = np.array(lookat)
    up_ref = np.array(up)

    def _unit(vec):
        # Zero-length vectors are returned unchanged to avoid dividing by zero.
        length = np.linalg.norm(vec)
        return vec / length if length > 0 else vec

    # Camera basis expressed in world coordinates.
    forward = _unit(target - eye)
    right = _unit(np.cross(up_ref, forward))
    down = np.cross(forward, right)

    # Camera-to-world: basis vectors as columns, camera position as translation.
    c2w = np.identity(4)
    c2w[:3, 0] = right
    c2w[:3, 1] = down
    c2w[:3, 2] = forward
    c2w[:3, 3] = eye

    # World-to-camera, split into the transposed rotation and the
    # translation that are handed to getWorld2View2.
    w2c = np.linalg.inv(c2w)
    R = w2c[:3, :3].T
    T = w2c[:3, 3]

    # Vertical FoV comes from the caller; horizontal FoV follows from the
    # image aspect ratio.
    FovY = np.deg2rad(fov_degrees)
    aspect_ratio = image_width / float(image_height)
    FovX = 2 * np.arctan(np.tan(FovY / 2.0) * aspect_ratio)

    view_np = getWorld2View2(R, T, trans, scale)
    world_view_transform = torch.tensor(view_np, dtype=torch.float32).transpose(0, 1).cuda()
    projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, fovX=FovX, fovY=FovY).transpose(0, 1).cuda()
    full_proj_transform = world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)).squeeze(0)

    return MiniCam(
        width=image_width,
        height=image_height,
        fovy=FovY,
        fovx=FovX,
        znear=znear,
        zfar=zfar,
        world_view_transform=world_view_transform,
        full_proj_transform=full_proj_transform,
    )



def diff_rasterization(
    x: torch.Tensor,
    deform_grad: Optional[torch.Tensor],
    gaussians: Optional[GaussianModel],
    view_cam,
    background_color: torch.Tensor,
    gaussians_active_sh: Optional[int] = None,
    guassians_cov: Optional[torch.Tensor] = None,
    gaussians_opa: Optional[torch.Tensor] = None,
    gaussians_shs: Optional[torch.Tensor] = None,
    scaling_modifier: Optional[float] = 1.,
    force_mask_data: Optional[bool] = False
) -> dict[str, torch.Tensor]:
    """Rasterize Gaussian kernels at positions ``x`` for one camera.

    Covariance, opacity and SH features come from ``gaussians`` when it is
    given, otherwise from the explicit ``guassians_cov`` / ``gaussians_opa`` /
    ``gaussians_shs`` / ``gaussians_active_sh`` arguments. If ``deform_grad``
    is given, the covariances are transformed by the ``deform_cov_by_F``
    Warp kernel before rendering.

    Args:
        x: Kernel centers; one row per kernel.
        deform_grad: Optional deformation gradients, reshaped to
            ``(-1, 3, 3)``; must provide one matrix per kernel.
        gaussians: Optional model supplying covariance, opacity and features.
        view_cam: Camera passed to ``get_rasterizer``.
        background_color: Background color passed to ``get_rasterizer``.
        gaussians_active_sh: SH degree used when ``gaussians`` is ``None``.
        guassians_cov: Covariances used when ``gaussians`` is ``None``.
        gaussians_opa: Opacities used when ``gaussians`` is ``None``.
        gaussians_shs: SH features used when ``gaussians`` is ``None``.
        scaling_modifier: Forwarded to ``gaussians.get_covariance``.
        force_mask_data: When true, render with constant color 1 per kernel
            instead of the SH features.

    Returns:
        Dict with the rendered image under ``"render"`` and the rendered
        depth under ``"depth"``.

    Raises:
        AssertionError: If the number of kernels disagrees between ``x``,
            the covariances and ``deform_grad``.
    """
    device = x.device
    means3D = x

    if gaussians is not None:
        cov3D_precomp = gaussians.get_covariance(scaling_modifier=scaling_modifier)
        opacity = gaussians.get_opacity
        shs = gaussians.get_features
        sh_degree = gaussians.active_sh_degree
    else:
        cov3D_precomp = guassians_cov
        opacity = gaussians_opa
        shs = gaussians_shs
        sh_degree = gaussians_active_sh

    assert means3D.shape[0] == cov3D_precomp.shape[0], \
        f"Shape mismatch: means3D {means3D.shape[0]} cov3D {cov3D_precomp.shape[0]}"

    if deform_grad is not None:
        tensor_F = torch.reshape(deform_grad, (-1, 3, 3))
        wp_F = torch2warp_mat33(tensor_F, dvc=device.type)

        assert cov3D_precomp.shape[0] == tensor_F.shape[0], \
            f"Shape mismatch: cov3D {cov3D_precomp.shape[0]} F {tensor_F.shape[0]}"

        # Hand the flattened covariances to Warp and deform them per kernel.
        wp_cov3D_precomp = wp.from_torch(
            cov3D_precomp.reshape(-1),
            dtype=wp.float32
        )
        wp_cov3D_deformed = wp.zeros_like(wp_cov3D_precomp)
        # NOTE(review): ``device.type`` drops the CUDA index; confirm this is
        # intended for multi-GPU runs.
        wp.launch(
            deform_cov_by_F,
            dim=tensor_F.shape[0],
            inputs=[wp_cov3D_precomp, wp_F, wp_cov3D_deformed],
            device=device.type
        )
        wp.synchronize()

        cov3D_deformed = wp.to_torch(wp_cov3D_deformed).reshape(-1, 6)
    else:
        cov3D_deformed = cov3D_precomp

    # Create zero tensor. We will use it to make pytorch return gradients of
    # the 2D (screen-space) means. It is allocated on the same device as
    # ``x`` rather than on whichever CUDA device happens to be current.
    screenspace_points = torch.zeros_like(
        means3D, dtype=means3D.dtype, requires_grad=True, device=device
    ) + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort: retaining the gradient is optional, e.g. it fails
        # when the tensor does not require grad.
        pass
    means2D = screenspace_points

    rasterizer = get_rasterizer(
        view_cam, sh_degree,
        debug=False, bg_color=background_color,
    )

    # The two render modes differ only in the color source.
    if force_mask_data:
        shs_input = None
        colors_input = torch.ones(means3D.shape[0], 3, device=device)
    else:
        shs_input = shs
        colors_input = None

    # Rasterize visible Gaussians to image.
    rendered_image, _radii, rendered_depth, _ = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs_input,
        colors_precomp=colors_input,
        opacities=opacity,
        scales=None,
        rotations=None,
        cov3D_precomp=cov3D_deformed
    )

    return {"render": rendered_image,
            "depth": rendered_depth
            }


def diff_rasterization_mask(
    x: torch.Tensor,
    deform_grad: Optional[torch.Tensor],
    gaussians: Optional[GaussianModel],
    view_cam,
    background_color: torch.Tensor,
    gaussians_active_sh: Optional[int] = None,
    guassians_cov: Optional[torch.Tensor] = None,
    gaussians_opa: Optional[torch.Tensor] = None,
    gaussians_shs: Optional[torch.Tensor] = None,
    scaling_modifier: Optional[float] = 1.,
    force_mask_data: Optional[bool] = False
) -> dict[str, torch.Tensor]:
    """Rasterize Gaussian kernels and additionally render a coverage mask.

    Two passes are rendered with the same geometry: a mask pass with
    constant color 1 per kernel over a black background, and the regular
    pass with SH features over ``background_color``.

    Covariance, opacity and SH features come from ``gaussians`` when it is
    given, otherwise from the explicit ``guassians_cov`` / ``gaussians_opa`` /
    ``gaussians_shs`` / ``gaussians_active_sh`` arguments. If ``deform_grad``
    is given, the covariances are transformed by the ``deform_cov_by_F``
    Warp kernel before rendering.

    Args:
        x: Kernel centers; one row per kernel.
        deform_grad: Optional deformation gradients, reshaped to
            ``(-1, 3, 3)``; must provide one matrix per kernel.
        gaussians: Optional model supplying covariance, opacity and features.
        view_cam: Camera passed to ``get_rasterizer``.
        background_color: Background color of the regular pass.
        gaussians_active_sh: SH degree used when ``gaussians`` is ``None``.
        guassians_cov: Covariances used when ``gaussians`` is ``None``.
        gaussians_opa: Opacities used when ``gaussians`` is ``None``.
        gaussians_shs: SH features used when ``gaussians`` is ``None``.
        scaling_modifier: Forwarded to ``gaussians.get_covariance``.
        force_mask_data: Accepted for signature parity with
            ``diff_rasterization``; not used by this function.

    Returns:
        Dict with the rendered image under ``"render"``, the rendered depth
        under ``"depth"`` and the mask pass under ``"mask"``.

    Raises:
        AssertionError: If the number of kernels disagrees between ``x``,
            the covariances and ``deform_grad``.
    """
    device = x.device
    means3D = x

    if gaussians is not None:
        cov3D_precomp = gaussians.get_covariance(scaling_modifier=scaling_modifier)
        opacity = gaussians.get_opacity
        shs = gaussians.get_features
        sh_degree = gaussians.active_sh_degree
    else:
        cov3D_precomp = guassians_cov
        opacity = gaussians_opa
        shs = gaussians_shs
        sh_degree = gaussians_active_sh

    assert means3D.shape[0] == cov3D_precomp.shape[0], \
        f"Shape mismatch: means3D {means3D.shape[0]} cov3D {cov3D_precomp.shape[0]}"

    if deform_grad is not None:
        tensor_F = torch.reshape(deform_grad, (-1, 3, 3))
        wp_F = torch2warp_mat33(tensor_F, dvc=device.type)

        assert cov3D_precomp.shape[0] == tensor_F.shape[0], \
            f"Shape mismatch: cov3D {cov3D_precomp.shape[0]} F {tensor_F.shape[0]}"

        # Hand the flattened covariances to Warp and deform them per kernel.
        wp_cov3D_precomp = wp.from_torch(
            cov3D_precomp.reshape(-1),
            dtype=wp.float32
        )
        wp_cov3D_deformed = wp.zeros_like(wp_cov3D_precomp)
        # NOTE(review): ``device.type`` drops the CUDA index; confirm this is
        # intended for multi-GPU runs.
        wp.launch(
            deform_cov_by_F,
            dim=tensor_F.shape[0],
            inputs=[wp_cov3D_precomp, wp_F, wp_cov3D_deformed],
            device=device.type
        )
        wp.synchronize()

        cov3D_deformed = wp.to_torch(wp_cov3D_deformed).reshape(-1, 6)
    else:
        cov3D_deformed = cov3D_precomp

    # Create zero tensor. We will use it to make pytorch return gradients of
    # the 2D (screen-space) means. It is allocated on the same device as
    # ``x`` rather than on whichever CUDA device happens to be current.
    screenspace_points = torch.zeros_like(
        means3D, dtype=means3D.dtype, requires_grad=True, device=device
    ) + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort: retaining the gradient is optional, e.g. it fails
        # when the tensor does not require grad.
        pass
    means2D = screenspace_points

    # Arguments shared by both render passes.
    shared_kwargs = dict(
        means3D=means3D,
        means2D=means2D,
        opacities=opacity,
        scales=None,
        rotations=None,
        cov3D_precomp=cov3D_deformed,
    )

    # Pass 1: mask, constant white kernels over a black background.
    rasterizer_mask = get_rasterizer(
        view_cam, sh_degree,
        debug=False, bg_color=torch.tensor([0, 0, 0], dtype=torch.float32, device=device),
    )
    rendered_image_m, _, _, _ = rasterizer_mask(
        shs=None,
        colors_precomp=torch.ones(means3D.shape[0], 3, device=device),
        **shared_kwargs
    )

    # Pass 2: regular appearance over the requested background.
    rasterizer = get_rasterizer(
        view_cam, sh_degree,
        debug=False, bg_color=background_color,
    )
    rendered_image, _radii, rendered_depth, _ = rasterizer(
        shs=shs,
        colors_precomp=None,
        **shared_kwargs
    )

    return {"render": rendered_image,
            "depth": rendered_depth,
            "mask": rendered_image_m
            }




def compute_bindings_xyz(
    p_curr: torch.Tensor,
    p_prev: torch.Tensor,
    k_prev: torch.Tensor,
    bindings: torch.Tensor,
):
    """Advance kernel centers by the bound particles' displacement.

    Each kernel moves by ``bindings @ (p_curr - p_prev)``. ``p_prev`` and
    ``k_prev`` are detached, so gradients only flow through ``p_curr``
    (and ``bindings``).

    Args:
        p_curr: Current particle positions.
        p_prev: Previous particle positions.
        k_prev: Previous kernel centers.
        bindings: Sparse kernel-by-particle binding matrix.

    Returns:
        k_curr: Updated kernel centers.
    """
    particle_shift = p_curr - p_prev.detach()
    kernel_shift = torch.sparse.mm(bindings, particle_shift).to_dense()
    return k_prev.detach() + kernel_shift


def compute_bindings_F(
    deform_grad: torch.Tensor,
    bindings: torch.Tensor,
):
    """Blend particle deformation gradients into one gradient per kernel.

    Args:
        deform_grad: Per-particle deformation gradients; flattened to
            ``(-1, 9)`` before blending.
        bindings: Sparse kernel-by-particle binding matrix.

    Returns:
        tensor_F: Per-kernel deformation gradients with shape ``(-1, 3, 3)``.
    """
    # Flatten each 3x3 matrix to a row so the blend is a single sparse matmul.
    flat_particle_F = deform_grad.reshape(-1, 9)
    flat_kernel_F = torch.sparse.mm(bindings, flat_particle_F).to_dense()

    # Restore the (kernels, 3, 3) layout.
    return flat_kernel_F.reshape(-1, 3, 3)


def preprocess_for_rasterization(
    obj_gaussians: List[GaussianModel],
    obj_deform_grad: List[torch.Tensor],
    obj_kernels_prev: List[torch.Tensor],
    obj_particles_curr: List[torch.Tensor],
    obj_particles_prev: List[torch.Tensor],
    obj_bindings: List[torch.Tensor],
    obj_scalings: List[float]
):
    """Collect per-object rasterization inputs and concatenate them.

    For every object the kernel centers and kernel deformation gradients are
    derived from the particles through the binding matrix, and covariance,
    opacity and SH features are read from the object's Gaussian model. All
    per-object tensors are concatenated along the first dimension.

    Args:
        obj_gaussians: Gaussian model of each object.
        obj_deform_grad: Particle deformation gradients of each object.
        obj_kernels_prev: Previous kernel centers of each object.
        obj_particles_curr: Current particle positions of each object.
        obj_particles_prev: Previous particle positions of each object.
        obj_bindings: Binding matrix of each object.
        obj_scalings: Scaling modifier applied to each object's covariance.

    Returns:
        Dict with keys ``"means3D"``, ``"deform_grad"``, ``"cov3D"``,
        ``"opacity"``, ``"shs"`` and ``"active_sh_degree"`` (taken from the
        first object).
    """
    # Updated kernel centers, one tensor per object.
    kernel_centers = [
        compute_bindings_xyz(p_curr, p_prev, k_prev, bindings)
        for p_curr, p_prev, k_prev, bindings in zip(
            obj_particles_curr, obj_particles_prev, obj_kernels_prev, obj_bindings
        )
    ]

    # Kernel deformation gradients, one tensor per object.
    kernel_F = [
        compute_bindings_F(deform_grad, bindings)
        for deform_grad, bindings in zip(obj_deform_grad, obj_bindings)
    ]

    # Appearance attributes straight from each Gaussian model.
    covariances, opacities, sh_features = [], [], []
    for gaussians, scaling in zip(obj_gaussians, obj_scalings):
        covariances.append(gaussians.get_covariance(scaling_modifier=scaling))
        opacities.append(gaussians.get_opacity)
        sh_features.append(gaussians.get_features)

    return {
        "means3D": torch.cat(kernel_centers, dim=0),
        "deform_grad": torch.cat(kernel_F, dim=0),
        "cov3D": torch.cat(covariances, dim=0),
        "opacity": torch.cat(opacities, dim=0),
        "shs": torch.cat(sh_features, dim=0),
        "active_sh_degree": obj_gaussians[0].active_sh_degree
    }
    
def preprocess_for_rasterization_no_binding(
    obj_gaussians: List[GaussianModel],
    obj_deform_grad: List[torch.Tensor],
    obj_particles_curr: List[torch.Tensor],
    obj_scalings: List[float]
):
    """Collect per-object rasterization inputs without a binding matrix.

    Particle positions and particle deformation gradients are used directly
    as kernel centers and kernel deformation gradients. Covariance, opacity
    and SH features are read from each object's Gaussian model. All
    per-object tensors are concatenated along the first dimension.

    Args:
        obj_gaussians: Gaussian model of each object.
        obj_deform_grad: Deformation gradients of each object.
        obj_particles_curr: Current particle positions of each object.
        obj_scalings: Scaling modifier applied to each object's covariance.

    Returns:
        Dict with keys ``"means3D"``, ``"deform_grad"``, ``"cov3D"``,
        ``"opacity"``, ``"shs"`` and ``"active_sh_degree"`` (taken from the
        first object).
    """
    # Without bindings, particles stand in for the kernels one-to-one.
    kernel_centers = list(obj_particles_curr)
    kernel_F = list(obj_deform_grad)

    # Appearance attributes straight from each Gaussian model.
    covariances, opacities, sh_features = [], [], []
    for gaussians, scaling in zip(obj_gaussians, obj_scalings):
        covariances.append(gaussians.get_covariance(scaling_modifier=scaling))
        opacities.append(gaussians.get_opacity)
        sh_features.append(gaussians.get_features)

    return {
        "means3D": torch.cat(kernel_centers, dim=0),
        "deform_grad": torch.cat(kernel_F, dim=0),
        "cov3D": torch.cat(covariances, dim=0),
        "opacity": torch.cat(opacities, dim=0),
        "shs": torch.cat(sh_features, dim=0),
        "active_sh_degree": obj_gaussians[0].active_sh_degree
    }
    
def save_depth_visualization(depth_tensor, save_dir, file_name, cmap='viridis'):
    """
    Write a color-mapped PNG visualization of a depth map.

    The depth values are min-max normalized to [0, 1] (a map whose range is
    at most 1e-6 becomes all zeros), colored with a matplotlib colormap and
    saved as ``<save_dir>/<file_name>.png``.

    Args:
        depth_tensor (torch.Tensor): 2D tensor [H, W]
        save_dir (str): Output directory path; created if missing
        file_name (str): Filename without extension
        cmap (str): Matplotlib colormap name
    """
    # Optional heavy dependencies are imported lazily, on first use.
    import matplotlib.pyplot as plt
    import torchvision

    depth = depth_tensor.detach().cpu().numpy()

    # Min-max normalization with a guard against (near-)constant input.
    lowest = np.min(depth)
    highest = np.max(depth)
    if highest - lowest > 1e-6:
        normalized = (depth - lowest) / (highest - lowest)
    else:
        normalized = depth * 0.0  # Uniform color

    # The colormap returns RGBA; keep RGB and move channels first (CHW).
    colormap = plt.get_cmap(cmap)
    rgb = colormap(normalized)[..., :3]
    image_chw = torch.from_numpy(rgb).permute(2, 0, 1).float()

    os.makedirs(save_dir, exist_ok=True)
    target = os.path.join(save_dir, f"{file_name}.png")

    # Values are already in [0, 1], so no further normalization is applied.
    torchvision.utils.save_image(
        image_chw,
        target,
        normalize=False,
        padding=0
    )
    
def interpolate_depth(depth_data):
    """
    Double the frame count of a depth sequence by linear interpolation.

    Even output frames are the input frames; each odd output frame is the
    mean of its two neighbouring input frames. The final output frame has
    no successor to average with and repeats the last input frame. A
    200-frame input therefore yields 400 frames.

    Parameters:
        depth_data: Depth data with shape (frames, H, W).

    Returns:
        Interpolated depth data with shape (2 * frames, H, W) and the same
        dtype as the input.
    """
    frames, height, width = depth_data.shape
    interpolated_depth = np.zeros((2 * frames, height, width), dtype=depth_data.dtype)
    if frames == 0:
        return interpolated_depth

    # Original frames land on the even indices (including the last one,
    # which the previous loop never wrote).
    interpolated_depth[0::2] = depth_data
    # Midpoints between consecutive frames land on the inner odd indices.
    # Halving before summing avoids overflow for integer depth dtypes.
    interpolated_depth[1:-1:2] = 0.5 * depth_data[:-1] + 0.5 * depth_data[1:]
    # Pad the trailing odd slot with the last frame.
    interpolated_depth[-1] = depth_data[-1]
    return interpolated_depth
    


def get_contour_edges(mask: torch.Tensor, dilate_radius: int = 5) -> torch.Tensor:
    """
    Extract the foreground contour of a mask and thicken it.

    The contour consists of foreground pixels that have at least one
    background pixel in their 3x3 neighbourhood. It is then dilated with a
    square window of side ``2 * dilate_radius + 1``.

    Args:
        mask: (H,W) binary mask (0:background, 1:foreground); a (B,H,W)
            batch is also accepted.
        dilate_radius: Half-width of the square dilation window.

    Returns:
        Boolean edge mask with singleton dimensions squeezed out, i.e.
        (H,W) for a 2D input.
    """
    # Ensure 4D input (batch=1, channel=1, H, W)
    if mask.dim() == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    elif mask.dim() == 3:
        mask = mask.unsqueeze(1)  # (B,1,H,W)

    # Erosion is a min-pool, written as a negated max-pool of the negated
    # mask. A plain max-pool dilates instead, which left the contour of a
    # binary mask empty. Padding is ignored by the pooling, so the image
    # border alone does not count as background.
    eroded = -F.max_pool2d(-mask.float(), kernel_size=3, stride=1, padding=1)
    contour = (mask > 0) & (eroded < 1)

    # Dilate the contour: a box filter is positive wherever any contour
    # pixel falls inside the window.
    kernel = torch.ones(1, 1, 2*dilate_radius+1, 2*dilate_radius+1,
                        device=mask.device)
    edge_mask = F.conv2d(contour.float(), kernel, padding=dilate_radius)
    return edge_mask.squeeze() > 0.5  # Output (H,W) boolean

def handle_3channel_mask(mask: torch.Tensor) -> torch.Tensor:
    """Reduce a channel-first mask to a single boolean map.

    A mask with one leading channel is thresholded directly; otherwise the
    per-pixel maximum over the leading dimension is thresholded. The
    threshold is 0.5 in both cases.
    """
    if mask.shape[0] != 1:
        # Multi-channel: a pixel is foreground if any channel exceeds 0.5.
        return mask.max(dim=0).values > 0.5
    return mask[0] > 0.5

def rdca_anchor_selection(points0: torch.Tensor,
                         gt_mask: torch.Tensor,
                         gt_depth: torch.Tensor,
                         render_depth: torch.Tensor,
                         edge_dilate: int = 7,
                         require_valid_depth: bool = False) -> torch.Tensor:
    """
    Select anchor points: those inside the image and away from mask edges.

    Args:
        points0: (N,2) pixel coordinates; column 0 is x and column 1 is y.
            Coordinates are rounded to the nearest pixel before lookup.
        gt_mask: (3,H,W) RGB mask or (1,H,W) single-channel
        gt_depth: Unused; kept so the signature matches the loss functions
            that call this helper.
        render_depth: (H,W) rendered depth. Its shape defines the image bounds.
        edge_dilate: Dilation radius passed to ``get_contour_edges``.
        require_valid_depth: If True, additionally require
            ``render_depth > 0`` at the point. Off by default, which keeps the
            selection identical to the previous behaviour.

    Returns:
        (N,) boolean tensor, True for points selected as anchors.
    """
    # Convert 3-channel mask to 2D
    binary_mask = handle_3channel_mask(gt_mask)

    # Get edge regions (H,W)
    edge_mask = get_contour_edges(binary_mask, edge_dilate)

    # Validate points against the image bounds
    H, W = render_depth.shape
    points = points0.round().long()
    valid_x = (points[:, 0] >= 0) & (points[:, 0] < W)
    valid_y = (points[:, 1] >= 0) & (points[:, 1] < H)

    # Clamp so out-of-bounds points can still be used as indices; they are
    # rejected below through valid_x / valid_y.
    # NOTE(review): edge_mask is indexed with bounds taken from render_depth,
    # which assumes gt_mask has the same spatial size - confirm with callers.
    y_coord = points[:, 1].clamp(0, H - 1)
    x_coord = points[:, 0].clamp(0, W - 1)

    # Final anchor = non-edge & in-bounds
    anchors = ~edge_mask[y_coord, x_coord] & valid_x & valid_y

    if require_valid_depth:
        # Only gather depth when the caller asks for it.
        anchors = anchors & (render_depth[y_coord, x_coord] > 0)
    return anchors

def rdca_loss(points0: torch.Tensor,
             points1: torch.Tensor,
             gt_mask: torch.Tensor,
             gt_depth: torch.Tensor,
             render_depth: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a global and an anchor-restricted point alignment loss.

    Returns:
        (point_loss, anchor_loss): the MSE between ``points1`` and ``points0``
        over all points, and the same MSE restricted to the points picked by
        ``rdca_anchor_selection``. ``anchor_loss`` is a constant 0.0 tensor
        when no point is selected.
    """
    # Alignment over every point.
    global_term = F.mse_loss(points1, points0)

    # Alignment over the anchor subset only.
    selected = rdca_anchor_selection(points0, gt_mask, gt_depth, render_depth)
    if not selected.any():
        return global_term, torch.tensor(0.0, device=points0.device)

    anchor_term = F.mse_loss(points1[selected], points0[selected])
    return global_term, anchor_term
    
def rdca_loss2(points0: torch.Tensor,
             points1: torch.Tensor,
             gt_mask: torch.Tensor,
             gt_depth: torch.Tensor,
             render_depth: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a global point alignment loss and a local depth-ordering loss.

    For every anchor chosen by ``rdca_anchor_selection`` a 5x5 patch centred on
    the anchor is sampled from both ``gt_depth`` and ``render_depth``. The
    anchor loss is ``1 - mean(Spearman rank correlation)`` between the two
    patches, averaged over anchors whose patches are not flat.

    Args:
        points0: (N,2) pixel coordinates (x, y); anchors and patch centres are
            taken from these.
        points1: (N,2) points compared against ``points0`` in the global loss.
        gt_mask: Mask forwarded to ``rdca_anchor_selection``.
        gt_depth: (H,W) ground-truth depth; its size is used to normalise the
            sampling coordinates.
        render_depth: (H,W) rendered depth, sampled with the same grid.
            NOTE(review): assumed to have the same size as gt_depth - confirm.

    Returns:
        (point_loss, anchor_loss). ``anchor_loss`` is a constant 0.0 tensor
        when there are no anchors or no anchor has a usable patch.
    """
    # Global alignment loss
    point_loss = F.mse_loss(points1, points0)

    # Anchor-based loss
    anchor_mask = rdca_anchor_selection(points0, gt_mask, gt_depth, render_depth)
    anchor_loss = torch.tensor(0.0, device=points0.device)
    if not anchor_mask.any():
        return point_loss, anchor_loss

    # Extract valid anchor points
    anchor_points = points0[anchor_mask]
    num_anchors = anchor_points.size(0)
    patch_size = 5
    radius = patch_size // 2

    # Generate sampling grid for local patches
    dx = torch.arange(-radius, radius + 1, device=anchor_points.device, dtype=torch.float)
    dy = torch.arange(-radius, radius + 1, device=anchor_points.device, dtype=torch.float)
    grid_x, grid_y = torch.meshgrid(dx, dy, indexing='xy')
    grid_offsets = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)  # (25, 2)

    # Calculate sampling coordinates
    centers = anchor_points.unsqueeze(1)  # (N, 1, 2)
    samples = centers + grid_offsets.unsqueeze(0)  # (N, 25, 2)

    # Normalize pixel coordinates to [-1, 1] (align_corners=True convention)
    h, w = gt_depth.shape[-2:]
    samples_normalized = torch.zeros_like(samples)
    samples_normalized[..., 0] = (samples[..., 0] / (w - 1)) * 2 - 1  # x coordinate
    samples_normalized[..., 1] = (samples[..., 1] / (h - 1)) * 2 - 1  # y coordinate
    sampling_grid = samples_normalized.view(num_anchors, patch_size, patch_size, 2)

    def _sample_patches(depth):
        # One flattened patch per anchor: (N, patch_size * patch_size).
        return F.grid_sample(
            depth.unsqueeze(0).unsqueeze(0).expand(num_anchors, -1, -1, -1),
            sampling_grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        ).view(num_anchors, -1)

    gt_patches = _sample_patches(gt_depth)
    render_patches = _sample_patches(render_depth)

    # Calculate Spearman correlation
    # NOTE: argsort returns integer indices, so the ranks (and therefore
    # anchor_loss) carry no gradient back to the depth maps or the points.
    def _rank(x):
        return torch.argsort(torch.argsort(x, dim=1), dim=1).float()

    gt_ranks = _rank(gt_patches)
    render_ranks = _rank(render_patches)

    # Compute mean-centered ranks
    gt_centered = gt_ranks - gt_ranks.mean(dim=1, keepdim=True)
    render_centered = render_ranks - render_ranks.mean(dim=1, keepdim=True)

    # Calculate correlation coefficients
    cov = (gt_centered * render_centered).mean(dim=1)
    gt_std = gt_centered.std(dim=1, unbiased=False)
    render_std = render_centered.std(dim=1, unbiased=False)
    corr = cov / (gt_std * render_std + 1e-6)

    # Keep only anchors whose patches actually vary. The test is made on the
    # raw depth values: argsort ranks are always a permutation of 0..24, so
    # their std is the same for every patch and cannot detect a flat patch,
    # whose rank order is arbitrary.
    valid = (
        (gt_patches.std(dim=1, unbiased=False) > 1e-6)
        & (render_patches.std(dim=1, unbiased=False) > 1e-6)
    )
    if valid.any():
        anchor_loss = 1 - corr[valid].mean()

    return point_loss, anchor_loss